{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "62549efe",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# The Prompt\n",
    "\n",
    "Let's see how to customize the prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "79fd8e41",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "nbsphinx": "hidden",
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": [
     "remove_cell"
    ]
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.insert(0, \"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1d3a2212-38fa-4ab1-bec2-e1f4aeb8af47",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import langchain\n",
    "from kork import (\n",
    "    CodeChain,\n",
    "    ast,\n",
    "    AstPrinter,\n",
    "    c_,\n",
    "    r_,\n",
    "    run_interpreter,\n",
    ")\n",
    "from langchain import PromptTemplate\n",
    "from kork import SimpleContextRetriever"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fb100da-29d9-442b-b89c-4b60bbc1967b",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The code below defines a place holder model for a chat model.\n",
    "Feel free to replace it with a real language model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3c5e0604-b50c-470d-855c-9b5253206a81",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "from typing import Any, List, Optional\n",
    "\n",
    "from langchain.chat_models.base import BaseChatModel\n",
    "from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult\n",
    "from pydantic import Extra\n",
    "\n",
    "\n",
    "class ToyChatModel(BaseChatModel):\n",
    "    response: str\n",
    "\n",
    "    class Config:\n",
    "        \"\"\"Configuration for this pydantic object.\"\"\"\n",
    "\n",
    "        extra = Extra.forbid\n",
    "        arbitrary_types_allowed = True\n",
    "\n",
    "    def _generate(\n",
    "        self, messages: List[BaseMessage], stop: Optional[List[str]] = None\n",
    "    ) -> ChatResult:\n",
    "        message = AIMessage(content=self.response)\n",
    "        generation = ChatGeneration(message=message)\n",
    "        return ChatResult(generations=[generation])\n",
    "\n",
    "    async def _agenerate(\n",
    "        self, messages: List[BaseMessage], stop: Optional[List[str]] = None\n",
    "    ) -> Any:\n",
    "        \"\"\"Async version of _generate.\"\"\"\n",
    "        message = AIMessage(content=self.response)\n",
    "        generation = ChatGeneration(message=message)\n",
    "        return ChatResult(generations=[generation])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07c73183-ffd6-450c-abe0-685e89ca5054",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Custom prompt\n",
    "\n",
    "Prompts take two optional input variables: the `language_name` and the `external_funcitons_block`. If the prompt uses such a variable, the variable will be auto-populated by the chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bd679388-6320-4933-aa1c-17ae8b719289",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "instruction_template = PromptTemplate(\n",
    "    template=\"\"\"\n",
    "You are coding in a language called the `{language_name}`.\n",
    "\n",
    "\n",
    "\n",
    "{external_functions_block}\n",
    "\n",
    "Begin!\\n\n",
    "\"\"\",\n",
    "    input_variables=[\"language_name\", \"external_functions_block\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab4d3808-50ab-4ade-aab7-677e23624cab",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Declare the chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5e2df25f-169d-4961-9fd9-4b10bb9f07c9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "chain = CodeChain.from_defaults(\n",
    "    llm=ToyChatModel(response=\"MEOW MEOW MEOW MEOW\"),  # The LLM to use\n",
    "    examples=[\n",
    "        (\"2**5\", r_(c_(math.pow, 2, 5))),\n",
    "        (\"take the log base 2 of 2\", r_(c_(math.log2, 2))),\n",
    "    ],  # Example programs\n",
    "    context=[math.pow, math.log2, math.log10],\n",
    "    instruction_template=instruction_template,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7d92f7b6-d1f5-41b5-826b-b338edc3831e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "environment, few_shot_prompt = chain.prepare_context(query=\"hello\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e72e5f9e-fe2d-4f6b-81c1-52c89ecd032e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ExternFunctionDef(name='pow', params=ParamList(params=[Param(name='x', type_='Any'), Param(name='y', type_='Any')]), return_type='Any', implementation=<built-in function pow>, doc_string='Return x**y (x to the power of y).'),\n",
       " ExternFunctionDef(name='log2', params=ParamList(params=[Param(name='x', type_='Any')]), return_type='Any', implementation=<built-in function log2>, doc_string='Return the base 2 logarithm of x.'),\n",
       " ExternFunctionDef(name='log10', params=ParamList(params=[Param(name='x', type_='Any')]), return_type='Any', implementation=<built-in function log10>, doc_string='Return the base 10 logarithm of x.')]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "environment.list_external_functions()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "06b5f073-bb8c-40a5-a816-52075417dee2",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "You are coding in a language called the `😼`.\n",
      "\n",
      "\n",
      "\n",
      "You have access to the following external functions:\n",
      "\n",
      "```😼\n",
      "extern fn pow(x: Any, y: Any) -> Any // Return x**y (x to the power of y).\n",
      "extern fn log2(x: Any) -> Any // Return the base 2 logarithm of x.\n",
      "extern fn log10(x: Any) -> Any // Return the base 10 logarithm of x.\n",
      "```\n",
      "\n",
      "\n",
      "Begin!\n",
      "\n",
      "Input: ```text\n",
      "2**5\n",
      "```\n",
      "\n",
      "Output: <code>var result = pow(2, 5)</code>\n",
      "\n",
      "Input: ```text\n",
      "take the log base 2 of 2\n",
      "```\n",
      "\n",
      "Output: <code>var result = log2(2)</code>\n",
      "\n",
      "Input: [user input]\n",
      "\n",
      "Output:\n"
     ]
    }
   ],
   "source": [
    "print(few_shot_prompt.format_prompt(query=\"[user input]\").to_string())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
